import torch
import torch.nn as nn

from modules.nclaw.warp import SVD
from einops.layers.torch import Rearrange


    
    
    

def invert_corotated_elasticity(F_batch: torch.Tensor, stress_batch: torch.Tensor) -> dict:
    """Estimate Young's modulus and Poisson's ratio from (F, stress) pairs.

    Each sample is fitted to the fixed-corotated stress form

        stress_i ~= 2*mu_i * (F_i - R_i) @ F_i^T + lam_i * J_i * (J_i - 1) * I

    where ``R_i`` is the rotation factor of the polar decomposition of ``F_i``
    and ``J_i = det(F_i)``. ``(2*mu_i, lam_i)`` is solved by linear least
    squares, both coefficients are clamped to at least 1e-6, and the Lame pair
    is converted to ``(E_i, nu_i)`` with ``nu_i`` clamped to ``[0, 0.49]``.

    Args:
        F_batch: Deformation gradients of shape ``(B, 3, 3)``.
        stress_batch: Observed stresses of shape ``(B, 3, 3)`` with the same
            dtype as ``F_batch``.

    Returns:
        A dict with
        - ``"E_mean"``, ``"E_std"``, ``"nu_mean"``, ``"nu_std"``: Python floats
          over the samples with ``0 < E < 1e6`` and ``0 <= nu < 0.5`` (all
          samples if none qualify). The stds are unbiased, so they are NaN
          when only one sample is used.
        - ``"relative_error"``: mean over samples of
          ``||stress_model - stress|| / (||stress|| + 1e-10)``, where the model
          uses the mean ``E``/``nu``.
        - ``"confidence"``: ``1 / (1 + relative_error)``.
        - ``"E_values"``, ``"nu_values"``: per-sample estimates as NumPy arrays.
    """
    batch_size = F_batch.shape[0]
    eye = torch.eye(3, dtype=F_batch.dtype, device=F_batch.device)

    # torch.linalg.svd returns (U, S, Vh) with Vh = V^T, so the polar rotation
    # is U @ Vh (not U @ V). Flipping the last right-singular vector when
    # det(U @ Vh) < 0 keeps R a proper rotation (det R = +1) for inverted F.
    U, _, Vh = torch.linalg.svd(F_batch)
    flip = torch.sign(torch.linalg.det(U @ Vh)).view(-1, 1, 1)
    Vh = torch.cat([Vh[:, :2, :], flip * Vh[:, 2:, :]], dim=1)
    R = torch.bmm(U, Vh)

    # Signed volume ratio; the product of singular values would lose the sign.
    J = torch.linalg.det(F_batch)

    corotated_part = torch.bmm(F_batch - R, F_batch.transpose(1, 2))  # (F - R) F^T
    volume_part = (J * (J - 1.0)).view(-1, 1, 1) * eye  # J (J - 1) I

    # One batched least-squares solve: A_i (9x2) @ [2*mu_i, lam_i] ~= vec(stress_i).
    A = torch.stack(
        [corotated_part.reshape(batch_size, 9), volume_part.reshape(batch_size, 9)],
        dim=-1,
    )
    b = stress_batch.reshape(batch_size, 9, 1)
    coeffs = torch.linalg.lstsq(A, b).solution[..., 0]  # (B, 2)

    two_mu = coeffs[:, 0].clamp(min=1e-6)
    lam = coeffs[:, 1].clamp(min=1e-6)
    mu = two_mu / 2.0
    nu_values = (lam / (2 * (lam + mu))).clamp(min=0.0, max=0.49)
    E_values = 2 * mu * (1 + nu_values)

    # Discard unphysical fits before aggregating; fall back to all samples.
    valid_mask = (E_values > 0) & (E_values < 1e6) & (nu_values >= 0) & (nu_values < 0.5)
    if valid_mask.any():
        valid_E, valid_nu = E_values[valid_mask], nu_values[valid_mask]
    else:
        valid_E, valid_nu = E_values, nu_values

    E_mean, E_std = valid_E.mean(), valid_E.std()
    nu_mean, nu_std = valid_nu.mean(), valid_nu.std()

    # Convert the mean (E, nu) back to Lame parameters and re-predict the batch.
    mu_mean = E_mean / (2 * (1 + nu_mean))
    lam_mean = E_mean * nu_mean / ((1 + nu_mean) * (1 - 2 * nu_mean))
    theoretical_stress = 2 * mu_mean * corotated_part + lam_mean * volume_part

    error = torch.norm((theoretical_stress - stress_batch).reshape(batch_size, -1), dim=1)
    stress_norm = torch.norm(stress_batch.reshape(batch_size, -1), dim=1)
    relative_error = (error / (stress_norm + 1e-10)).mean().item()
    confidence = 1.0 / (1.0 + relative_error)

    return {
        "E_mean": E_mean.item(),
        "E_std": E_std.item(),
        "nu_mean": nu_mean.item(),
        "nu_std": nu_std.item(),
        "relative_error": relative_error,
        "confidence": confidence,
        "E_values": E_values.detach().cpu().numpy(),
        "nu_values": nu_values.detach().cpu().numpy()
    }
    
    
    
    
def invert_corotated_elasticity_my(F: torch.Tensor, stress: torch.Tensor) -> dict:
    """Least-squares fit of corotated Lame parameters for one sample.

    The stress is modelled as ``mu * 2 (F - U Vh) F^T + la * J (J - 1) I``,
    with ``U, sigma, Vh`` from the project ``SVD`` module and
    ``J = prod(sigma)``.

    Args:
        F: A single deformation gradient. A batch dim of 1 is added before the
            SVD and both model terms must flatten to 9 entries, so (3, 3).
        stress: The matching stress with 9 elements.

    Returns:
        dict with 0-d tensors ``"mu"`` and ``"la"``.
    """
    F_b = F.unsqueeze(0)  # the SVD module is fed a batch of one
    U, sigma, Vh = SVD()(F_b)

    rotation = torch.matmul(U, Vh)
    vol_ratio = torch.prod(sigma, dim=1).view(-1, 1, 1)
    eye = torch.eye(3, dtype=F_b.dtype, device=F_b.device).unsqueeze(0)

    # Drop the batch dim again so both terms are plain 3x3 matrices.
    shear_term = (2 * torch.matmul(F_b - rotation, F_b.transpose(1, 2)))[0]
    volume_term = (vol_ratio * (vol_ratio - 1) * eye)[0]

    design = torch.stack((shear_term.reshape(-1), volume_term.reshape(-1)), dim=1)
    coeffs = torch.linalg.lstsq(design, stress.reshape(-1)).solution
    mu, la = coeffs.unbind(0)

    return {"mu": mu, "la": la}
    
def invert_corotated_elasticity_batched(F: torch.Tensor, stress: torch.Tensor) -> dict:
    """Fit corotated Lame parameters per sample and report batch statistics.

    Each sample's stress is fitted by least squares to
    ``mu * 2 (F - U Vh) F^T + la * J (J - 1) I``, with ``U, sigma, Vh`` from the
    project ``SVD`` module and ``J = prod(sigma)``.

    Args:
        F: Deformation gradients; dims 1/2 are transposed and each sample is
            flattened to 9 entries, so presumably ``(B, 3, 3)``.
        stress: Stresses with 9 elements per sample.

    Returns:
        dict with ``"mu"``/``"la"`` (batch means of the per-sample fits) and
        ``"mu_var"``/``"la_var"`` (unbiased variances; NaN for a single sample).
    """
    svd = SVD()  # NOTE(review): assumed to accept batched (B, 3, 3) input.

    U, sigma, Vh = svd(F)
    J = torch.prod(sigma, dim=1).unsqueeze(-1).unsqueeze(-1)  # (B, 1, 1)
    I = torch.eye(3, dtype=F.dtype, device=F.device).unsqueeze(0)  # (1, 3, 3)

    corotated_mat = 2 * torch.matmul(F - torch.matmul(U, Vh), F.transpose(1, 2))
    volume_mat = J * (J - 1) * I  # broadcast to (B, 3, 3)

    # A: (B, 9, 2), b: (B, 9, 1) -> one batched least-squares solve, (B, 2).
    A = torch.stack([corotated_mat.reshape(-1, 9), volume_mat.reshape(-1, 9)], dim=-1)
    b = stress.reshape(-1, 9, 1)
    solutions = torch.linalg.lstsq(A, b).solution[..., 0]

    mu = solutions[:, 0]
    la = solutions[:, 1]

    return {
        "mu": torch.mean(mu),
        "mu_var": torch.var(mu, unbiased=True),
        "la": torch.mean(la),
        "la_var": torch.var(la, unbiased=True)
    }
    
    
def solve_e_corotated(
    F_all: torch.Tensor,
    stress_all: torch.Tensor,
    *,
    w: float = 1e-2,
    epsilon: float = 1.0,
) -> dict:
    """Fit corotated Lame parameters sample by sample and build a fitting loss.

    Each sample ``i`` is fitted by least squares to
    ``mu_i * 2 (F_i - U_i Vh_i) F_i^T + la_i * J_i (J_i - 1) I``. Here
    ``U_i, sigma_i, Vh_i`` come from the project ``SVD`` module and
    ``J_i = prod(sigma_i)``, clamped to at least 1e-5.

    Args:
        F_all: Deformation gradients indexed as ``(B, 3, 3)``.
        stress_all: Target stresses indexed as ``(B, 3, 3)``.
        w: Weight applied to each per-sample residual norm.
        epsilon: Added to the detached variance of ``mu`` in the loss
            denominator.

    Returns:
        dict with
        - ``"mu"``, ``"la"``: means of the per-sample fits;
        - ``"mu_var"``, ``"la_var"``: unbiased variances (NaN for B == 1);
        - ``"loss"``: ``sum_i w * ||pred_i - stress_i|| / (epsilon + mu_var)``.
          ``mu_var`` is detached, so the denominator contributes no gradient.
    """
    svd = SVD()

    batch_size = F_all.shape[0]
    # Match the inputs' dtype/device so float64 or GPU inputs are not silently
    # cast to the default CPU float dtype.
    mu_all = torch.zeros(batch_size, dtype=F_all.dtype, device=F_all.device)
    la_all = torch.zeros_like(mu_all)
    stress_pred = torch.zeros_like(stress_all)
    I = torch.eye(3, dtype=F_all.dtype, device=F_all.device).unsqueeze(0)

    for i in range(batch_size):
        F = F_all[i:i + 1]  # keep a batch dim of 1 for the SVD module
        stress = stress_all[i:i + 1]
        U, sigma, Vh = svd(F)
        # clamp() is not in-place: assign the result so the lower bound applies.
        J = torch.prod(sigma, dim=1).view(-1, 1, 1).clamp(min=1e-5)

        corotated_mat = (2 * torch.matmul(F - torch.matmul(U, Vh), F.transpose(1, 2)))[0]
        volume_mat = (J * (J - 1) * I)[0]

        A = torch.stack([corotated_mat.reshape(-1), volume_mat.reshape(-1)], dim=1)
        solution = torch.linalg.lstsq(A, stress.reshape(-1)).solution

        mu, la = solution[0], solution[1]
        mu_all[i] = mu
        la_all[i] = la
        stress_pred[i] = mu * corotated_mat + la * volume_mat

    mu_var = torch.var(mu_all, unbiased=True)
    # Frobenius norm of each sample's residual, weighted and normalised by the
    # detached spread of mu, then summed over the batch.
    residual_norms = torch.linalg.vector_norm(
        (stress_pred - stress_all).flatten(start_dim=1), dim=1
    )
    loss = (w * residual_norms / (epsilon + mu_var.detach())).sum()

    return {
        "mu": torch.mean(mu_all),
        "mu_var": mu_var,
        "la": torch.mean(la_all),
        "la_var": torch.var(la_all, unbiased=True),
        "loss": loss
    }
    
def solve_p_identity(F_in_all: torch.Tensor, F_out_all: torch.Tensor) -> dict:
    """Summarise the per-sample mean squared difference between two batches.

    ``d_i = mean((F_in_all[i] - F_out_all[i]) ** 2)`` is computed for every
    sample along the leading dimension.

    Args:
        F_in_all: Tensor whose leading dimension indexes samples.
        F_out_all: Tensor with the same shape as ``F_in_all``.

    Returns:
        dict with ``"d"`` (mean of ``d_i``) and ``"d_var"`` (unbiased variance
        of ``d_i``; NaN when there is a single sample). Results keep the input
        dtype and device.
    """
    sq_diff = (F_in_all - F_out_all) ** 2
    # Average over every dim except the sample dim; 1-D inputs already hold
    # one value per sample.
    d_all = sq_diff.flatten(start_dim=1).mean(dim=1) if sq_diff.dim() > 1 else sq_diff

    return {
        "d": torch.mean(d_all),
        "d_var": torch.var(d_all, unbiased=True),
    }